# USAGE
# python yolo.py images/baggage_claim.jpg
# (expects yolov3.cfg, yolov3.weights and my.names under ./yolo-net)

# import the necessary packages
import numpy as np
import time
import cv2
import os,sys


class YOLONet(object):
    """Darknet YOLO detector wrapper built on OpenCV's DNN module.

    Loads a Darknet ``.cfg``/``.weights`` pair plus a class-name file, then
    runs detection on images, draws the kept boxes onto the image in place and
    returns them as a dict.

    Args:
        pWeightPth: path to the Darknet ``.weights`` file.
        pCfgPth: path to the Darknet ``.cfg`` file.
        pNamePth: path to a text file with one class label per line.
        pConfidence: minimum class score for a detection to be kept; also the
            score threshold passed to non-maxima suppression.
        pThreshold: overlap threshold passed to non-maxima suppression.
        pInputSize: (width, height) the image is resized to when building the
            network input blob; should match what the model was configured for.
    """

    def __init__(self, pWeightPth, pCfgPth, pNamePth, pConfidence=0.5,
                 pThreshold=0.3, pInputSize=(416, 416)):
        super(YOLONet, self).__init__()
        self.weightsPath = pWeightPth
        self.configPath = pCfgPth
        self.labelsPath = pNamePth
        self.confidence = pConfidence
        self.threshold = pThreshold
        self.inputSize = tuple(pInputSize)
        self.net = None
        self.layerNames = None   # names of the unconnected (output) layers
        self.LABELS = []
        self.COLORS = []         # one random BGR colour per label
        self.loadYoloNet()

    def loadYoloNet(self):
        """Read the labels, pick per-class colours and load the network."""
        with open(self.labelsPath) as f:
            self.LABELS = f.read().strip().split("\n")

        # Use a private, seeded generator *before* drawing the colours so each
        # class gets the same colour on every run. Seeding the global RNG after
        # the draw, as before, had no effect on the colours.
        rng = np.random.RandomState(42)
        self.COLORS = rng.randint(0, 255, size=(len(self.LABELS), 3),
                                  dtype="uint8")

        print("[INFO] loading YOLO from disk...")
        print(cv2.__version__.split('.'))
        print(self.configPath)
        print(self.weightsPath)
        # readNetFromDarknet returns a cv2.dnn.Net on OpenCV 3.4+ and 4.x.
        # The former cv2.dnn_DetectionModel branch produced a Model object
        # without getLayerNames()/setInput()/forward(), and its version parsing
        # failed on releases such as "4.5.1" or "4.5.5-dev".
        self.net = cv2.dnn.readNetFromDarknet(self.configPath, self.weightsPath)
        print('load end')
        self.layerNames = self._resolveOutputLayerNames(
            self.net.getLayerNames(), self.net.getUnconnectedOutLayers())

    @staticmethod
    def _resolveOutputLayerNames(layerNames, outLayerIds):
        """Map 1-based output-layer ids to layer names.

        OpenCV < 4.5.4 returns the ids shaped (N, 1) and newer versions return
        shape (N,); flattening accepts both.
        """
        return [layerNames[int(i) - 1]
                for i in np.asarray(outLayerIds).flatten()]

    @staticmethod
    def _parseDetections(layerOutputs, W, H, minConfidence):
        """Convert raw YOLO output rows into pixel-space boxes.

        Each row is read as [centerX, centerY, width, height, ..., class
        scores from index 5 on], with the first four values relative to the
        image size. Rows whose best class score is not above
        ``minConfidence`` are dropped.

        Returns:
            (boxes, confidences, classIDs): parallel lists where each box is
            [x, y, w, h] with (x, y) the top-left corner in pixels.
        """
        boxes = []
        confidences = []
        classIDs = []
        for output in layerOutputs:
            for detection in output:
                scores = detection[5:]
                classID = np.argmax(scores)
                confidence = scores[classID]
                if confidence > minConfidence:
                    # scale the centre-based box back to image pixels
                    box = detection[0:4] * np.array([W, H, W, H])
                    (centerX, centerY, width, height) = box.astype("int")
                    # derive the top-left corner from the centre
                    x = int(centerX - (width / 2))
                    y = int(centerY - (height / 2))
                    boxes.append([x, y, int(width), int(height)])
                    confidences.append(float(confidence))
                    classIDs.append(int(classID))
        return boxes, confidences, classIDs

    def fandObjects(self, image):
        """Detect objects in ``image`` and draw them onto it in place.

        Args:
            image: image array as returned by ``cv2.imread`` (the blob is built
                with ``swapRB=True``, so BGR channel order is assumed).

        Returns:
            (image, outboxes): the same image object with boxes and labels
            drawn on it, and a dict mapping each kept detection's index to
            ``{'x', 'y', 'w', 'h', 't', 's', 'imgW', 'imgH'}``. Here (x, y) is
            the top-left corner in pixels, 't' is the class label, 's' is the
            score and imgW/imgH is the image size.
        """
        (H, W) = image.shape[:2]
        blob = cv2.dnn.blobFromImage(image, 1 / 255.0, self.inputSize,
                                     swapRB=True, crop=False)
        self.net.setInput(blob)
        start = time.time()
        layerOutputs = self.net.forward(self.layerNames)
        end = time.time()
        # show timing information on YOLO
        print("[INFO] YOLO took {:.6f} seconds".format(end - start))

        boxes, confidences, classIDs = self._parseDetections(
            layerOutputs, W, H, self.confidence)

        outboxes = {}
        if not boxes:
            return image, outboxes

        # Non-maxima suppression drops weak, overlapping boxes. Depending on
        # the OpenCV version the result is (), an (N, 1) array or an (N,)
        # array, so normalise it to a flat array.
        idxs = np.asarray(cv2.dnn.NMSBoxes(boxes, confidences,
                                           self.confidence,
                                           self.threshold)).flatten()

        for i in idxs:
            i = int(i)
            (x, y, w, h) = boxes[i]
            print(boxes[i])
            label = self.LABELS[classIDs[i]]
            outboxes[i] = {'x': x, 'y': y, 'w': w, 'h': h, 't': label,
                           's': confidences[i], 'imgW': W, 'imgH': H}
            # draw a bounding box rectangle and label on the image
            color = [int(c) for c in self.COLORS[classIDs[i]]]
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
            text = "{}: {:.4f}".format(label, confidences[i])
            cv2.putText(image, text, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, color, 2)

        return image, outboxes

    # Correctly spelled alias; ``fandObjects`` is kept for existing callers.
    findObjects = fandObjects

def getDandCPoint(yoloobj):
    """Pair every 'd' detection with a 'c' detection by vertical order.

    Args:
        yoloobj: dict of detections such as the ``outboxes`` returned by
            ``YOLONet.fandObjects``; only each value's ``'t'`` (label) and
            ``'y'`` (top edge) entries are read.

    Returns:
        A list of ``[d_box, c_box]`` pairs, where the i-th topmost 'd' box is
        paired with the i-th topmost 'c' box. Returns ``None`` when the numbers
        of 'c' and 'd' detections differ. Detections with other labels are
        ignored.
    """
    clist = [v for v in yoloobj.values() if v['t'] == 'c']
    dlist = [v for v in yoloobj.values() if v['t'] == 'd']
    if len(clist) != len(dlist):
        return None
    sorted_clist = sorted(clist, key=lambda box: box['y'])
    sorted_dlist = sorted(dlist, key=lambda box: box['y'])
    return [[d, c] for d, c in zip(sorted_dlist, sorted_clist)]

def main(imgPth, modelDir=None):
    """Run YOLO detection on one image, print the results and display it.

    Args:
        imgPth: path of the image to process.
        modelDir: directory holding ``yolov3.cfg``, ``yolov3.weights`` and
            ``my.names``; defaults to ``<cwd>/yolo-net``.
    """
    if modelDir is None:
        modelDir = os.path.join(os.getcwd(), 'yolo-net')
    labelsPath = os.path.join(modelDir, 'my.names')
    weightsPath = os.path.join(modelDir, "yolov3.weights")
    configPath = os.path.join(modelDir, "yolov3.cfg")

    # cv2.imread returns None instead of raising when the file can't be read.
    imImg = cv2.imread(imgPth)
    if imImg is None:
        print('cannot read image: {}'.format(imgPth))
        return

    yolonet = YOLONet(weightsPath, configPath, labelsPath)
    image, outboxes = yolonet.fandObjects(imImg)
    print(outboxes)
    for k, v in outboxes.items():
        print(k, v)

    olist = getDandCPoint(outboxes)
    if olist is None:
        # getDandCPoint returns None when the 'c' and 'd' counts differ.
        print('[WARN] numbers of "c" and "d" detections differ; cannot pair them')
        olist = []
    # Only report the sizes of the first two pairs when they exist; indexing
    # olist[0]/olist[1] unconditionally crashed with fewer than two pairs.
    print(olist, len(olist), *[len(pair) for pair in olist[:2]])
    for i, v in enumerate(olist):
        print(i, v)

    scalefloat = 1.0
    w = int(image.shape[1] * scalefloat)
    h = int(image.shape[0] * scalefloat)
    resizeimg = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)

    cv2.imshow("Image", resizeimg)
    cv2.moveWindow('Image', 500, 0)
    cv2.waitKey(0)
    cv2.destroyAllWindows()



if __name__ == '__main__':
    # The script takes exactly one positional argument: the image to analyse.
    if len(sys.argv) <= 1:
        print('please input image path')
    else:
        main(sys.argv[1])